package com.itmck.interceptor;

import cn.hutool.core.collection.CollectionUtil;
import cn.hutool.core.util.StrUtil;
import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.core.toolkit.StringUtils;
import com.itmck.interceptor.annotation.EncryptTransaction;
import com.itmck.interceptor.annotation.SensitiveData;
import com.itmck.util.DesEncryptThreeUtil;
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.executor.parameter.ParameterHandler;
import org.apache.ibatis.executor.resultset.ResultSetHandler;
import org.apache.ibatis.mapping.*;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.apache.ibatis.session.Configuration;
import org.jetbrains.annotations.Nullable;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;

import javax.annotation.Resource;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.sql.PreparedStatement;
import java.sql.Statement;
import java.util.*;
import java.util.stream.Collectors;

@Slf4j
@Component
@Intercepts({
        @Signature(type = ParameterHandler.class, method = "setParameters", args = PreparedStatement.class),
        @Signature(type = ResultSetHandler.class, method = "handleResultSets", args = {Statement.class})
})
/**
 * MyBatis plugin that encrypts sensitive values on write and decrypts them on read.
 * <ul>
 *   <li>{@code ParameterHandler#setParameters}: for INSERT / UPDATE statements, parameter objects whose
 *   class is annotated with {@link SensitiveData} are handed to {@link DesEncryptFactory#encrypt}; for
 *   Map-shaped UPDATE parameters the values bound to encrypted columns are rewritten in place.</li>
 *   <li>{@code ResultSetHandler#handleResultSets}: returned objects whose class is annotated with
 *   {@link SensitiveData} are handed to {@link DesEncryptFactory#decrypt}.</li>
 * </ul>
 */
public class DesEncryptInterceptor implements Interceptor {


    @Resource
    private DesEncryptFactory desEncryptFactory;

    @Override
    public Object intercept(Invocation invocation) throws Throwable {

        Object target = invocation.getTarget();
        if (target instanceof ParameterHandler) {
            // INSERT / UPDATE: encrypt sensitive values before they are bound to the statement
            this.parameterHandlerDesEncryptMethod(invocation);
        } else if (target instanceof ResultSetHandler) {
            // Query results: decrypt after MyBatis has mapped the rows
            return this.resultSetHandlerEncryptMethod(invocation);
        }
        return invocation.proceed();
    }

    /**
     * Runs {@code handleResultSets} and decrypts every returned object whose class carries {@link SensitiveData}.
     *
     * @param invocation the intercepted {@code handleResultSets} call
     * @return the (decrypted in place) result, or {@code null} when the handler returned {@code null}
     */
    @Nullable
    private Object resultSetHandlerEncryptMethod(Invocation invocation) throws InvocationTargetException, IllegalAccessException {
        Object resultObject = invocation.proceed();
        if (Objects.isNull(resultObject)) {
            return null;
        }
        if (resultObject instanceof List) {
            // Check each row on its own: a null row or a mixed-type list must not break decryption
            for (Object row : (List<?>) resultObject) {
                if (needToDecrypt(row)) {
                    desEncryptFactory.decrypt(row);
                }
            }
        } else if (needToDecrypt(resultObject)) {
            desEncryptFactory.decrypt(resultObject);
        }
        return resultObject;
    }

    /**
     * Encrypts the parameter object of INSERT and UPDATE statements before parameters are bound.
     * Other command types are left untouched.
     *
     * @param invocation the intercepted {@code setParameters} call
     */
    private void parameterHandlerDesEncryptMethod(Invocation invocation) {
        ParameterHandler parameterHandler = (ParameterHandler) invocation.getTarget();
        // The parameter object, i.e. the instance of the mapper's parameter type
        Object parameterObject = parameterHandler.getParameterObject();
        MetaObject delegateMetaObject = SystemMetaObject.forObject(parameterHandler);
        // Derive the command type from the MappedStatement instead of a handler-specific field,
        // so this works with both MyBatis' and MyBatis-Plus' parameter handlers
        MappedStatement ms = (MappedStatement) delegateMetaObject.getValue("mappedStatement");
        SqlCommandType sqlCommandType = ms.getSqlCommandType();
        if (SqlCommandType.INSERT == sqlCommandType) {
            log.info("当前sql执行类型为:{}", sqlCommandType);
            delWithObj(parameterObject);
            return;
        }
        // Only UPDATE remains to be handled; a statement without parameters has nothing to encrypt
        if (SqlCommandType.UPDATE != sqlCommandType || parameterObject == null) {
            return;
        }
        if (parameterObject instanceof Map) {
            BoundSql boundSql = (BoundSql) delegateMetaObject.getValue("boundSql");
            encryptMapUpdateParameters(ms, boundSql, parameterObject);
        } else {
            delWithObj(parameterObject);
        }
    }

    /**
     * Encrypts, inside a Map-shaped UPDATE parameter (the MyBatis-Plus style update), the values bound
     * to columns marked for encryption. The columns are located by parsing the SET clause, because the
     * property paths of a Map parameter do not carry column names. Values are replaced in place.
     *
     * @param ms              the statement being executed; supplies the entity type and configuration
     * @param boundSql        the SQL and parameter mappings of the statement
     * @param parameterObject the Map parameter object (non-null)
     */
    private void encryptMapUpdateParameters(MappedStatement ms, BoundSql boundSql, Object parameterObject) {
        List<String> encodeKeys = getObjectEncryptKeys(ms);
        if (CollectionUtil.isEmpty(encodeKeys)) {
            // No columns need encryption
            return;
        }
        Map<Integer, String> propertyNameIndexMap = this.getSqlSegmentColumnIndex(boundSql.getSql())
                .entrySet()
                .stream()
                .filter(ks -> isContains(encodeKeys, ks.getValue()))
                .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
        if (propertyNameIndexMap.isEmpty()) {
            return;
        }
        Configuration configuration = ms.getConfiguration();
        MetaObject metaObject = configuration.newMetaObject(parameterObject);
        List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
        // Values are rewritten in place, so a property bound to several encrypted columns
        // must only be encrypted once
        Set<String> encryptedProperties = new HashSet<>();
        for (int i = 0; i < parameterMappings.size(); i++) {
            if (!propertyNameIndexMap.containsKey(i)) {
                continue;
            }
            String propertyName = parameterMappings.get(i).getProperty();
            if (!metaObject.hasGetter(propertyName)) {
                continue;
            }
            Object value = metaObject.getValue(propertyName);
            // Only String encryption is implemented; null and non-String values are left untouched
            if (value instanceof String && encryptedProperties.add(propertyName)) {
                metaObject.setValue(propertyName, DesEncryptThreeUtil.encode3Des((String) value));
            }
        }
    }

    /**
     * Case-insensitive membership test of a column name against the encrypted column names.
     *
     * @return {@code true} when {@code str} is non-blank and equals one of {@code encodeKeys} ignoring case
     */
    private boolean isContains(List<String> encodeKeys, String str) {
        return StringUtils.isNotBlank(str) && encodeKeys.stream().anyMatch(str::equalsIgnoreCase);
    }

    /**
     * Handles non-Map parameter objects: when the object's class is annotated with {@link SensitiveData},
     * its declared fields are passed to {@link DesEncryptFactory#encrypt}. {@code null} is ignored.
     */
    private void delWithObj(Object parameterObject) {
        if (parameterObject == null) {
            return;
        }
        Class<?> parameterObjectClass = parameterObject.getClass();
        // Only classes annotated with @SensitiveData are processed
        SensitiveData sensitiveData = AnnotationUtils.findAnnotation(parameterObjectClass, SensitiveData.class);
        if (Objects.nonNull(sensitiveData)) {
            Field[] declaredFields = parameterObjectClass.getDeclaredFields();
            desEncryptFactory.encrypt(declaredFields, parameterObject);
        }
    }

    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    @Override
    public void setProperties(Properties properties) {
        // This plugin has no configurable properties
    }


    /**
     * Maps each JDBC placeholder ({@code ?}) that appears in the SET clause of an UPDATE statement to the
     * column it is assigned to.
     * <p>
     * Example: {@code update user set name=? , age=?, email=?  where id=?} yields
     * {@code {0=name, 1=age, 2=email}}. Placeholder indexes are counted from the start of the statement,
     * so they line up with {@link BoundSql#getParameterMappings()}. Assignments without a placeholder
     * (e.g. {@code update_time=now()}) consume no index. Keywords are matched case-insensitively on word
     * boundaries. Commas, {@code =} and {@code WHERE} nested in parentheses or quoted literals are ignored.
     *
     * @param sql the prepared UPDATE statement; may be blank
     * @return placeholder index to column name (whitespace removed); empty when no SET clause is found
     */
    private Map<Integer, String> getSqlSegmentColumnIndex(String sql) {
        Map<Integer, String> columnByIndex = new HashMap<>();
        if (StringUtils.isBlank(sql)) {
            return columnByIndex;
        }
        int placeholderIndex = 0;     // running index of '?' from the start of the statement
        int depth = 0;                // parenthesis nesting level
        boolean inLiteral = false;    // inside a '...' string literal
        boolean inSetClause = false;
        int assignmentStart = -1;     // start offset of the current "column = value" assignment
        String currentColumn = null;  // column of the current assignment, once its '=' has been seen
        for (int i = 0; i < sql.length(); i++) {
            char c = sql.charAt(i);
            if (inLiteral) {
                if (c == '\'') {
                    inLiteral = false;
                }
                continue;
            }
            if (c == '\'') {
                inLiteral = true;
            } else if (c == '(') {
                depth++;
            } else if (c == ')') {
                depth = Math.max(0, depth - 1);
            } else if (c == '?') {
                if (inSetClause && currentColumn != null) {
                    columnByIndex.put(placeholderIndex, currentColumn);
                }
                placeholderIndex++;
            } else if (depth == 0) {
                if (!inSetClause) {
                    if (isKeywordAt(sql, i, "SET")) {
                        inSetClause = true;
                        assignmentStart = i + "SET".length();
                        i = assignmentStart - 1;
                    }
                } else if (c == '=' && currentColumn == null) {
                    currentColumn = sql.substring(assignmentStart, i).replaceAll("\\s+", "");
                } else if (c == ',') {
                    currentColumn = null;
                    assignmentStart = i + 1;
                } else if (isKeywordAt(sql, i, "WHERE")) {
                    break;
                }
            }
        }
        return columnByIndex;
    }

    /**
     * Returns whether {@code keyword} occurs at {@code pos} (ignoring case) as a whole word.
     */
    private static boolean isKeywordAt(String sql, int pos, String keyword) {
        int end = pos + keyword.length();
        if (end > sql.length() || !sql.regionMatches(true, pos, keyword, 0, keyword.length())) {
            return false;
        }
        return (pos == 0 || !isIdentifierChar(sql.charAt(pos - 1)))
                && (end == sql.length() || !isIdentifierChar(sql.charAt(end)));
    }

    /**
     * Characters that can be part of an (optionally qualified or quoted) SQL identifier.
     */
    private static boolean isIdentifierChar(char c) {
        return Character.isLetterOrDigit(c) || c == '_' || c == '$' || c == '.' || c == '`' || c == '"';
    }

    /**
     * Collects the column names to encrypt for the statement's parameter type.
     * <p>
     * When the parameter type is annotated with {@link SensitiveData}, every declared field annotated with
     * {@link EncryptTransaction} contributes one column. The {@code @TableField} value is used when
     * non-blank. Otherwise the camelCase field name is converted to underscore case. Fields marked
     * {@code @TableField(exist = false)} are skipped.
     *
     * @param ms MappedStatement
     * @return the columns to encrypt; empty when none (never {@code null})
     */
    private List<String> getObjectEncryptKeys(MappedStatement ms) {
        List<String> kks = new ArrayList<>();
        ParameterMap parameterMap = ms.getParameterMap();
        Class<?> type = parameterMap == null ? null : parameterMap.getType();
        if (type == null || AnnotationUtils.findAnnotation(type, SensitiveData.class) == null) {
            return kks;
        }
        for (Field field : type.getDeclaredFields()) {
            if (field.getAnnotation(EncryptTransaction.class) == null) {
                continue;
            }
            TableField tableField = field.getAnnotation(TableField.class);
            if (tableField != null && !tableField.exist()) {
                // Not a database column
                continue;
            }
            if (tableField != null && StringUtils.isNotBlank(tableField.value())) {
                kks.add(tableField.value());
            } else {
                // camelCase -> under_score
                kks.add(StrUtil.toUnderlineCase(field.getName()));
            }
        }
        return kks;
    }

    /**
     * Returns whether {@code object} is non-null and its class is annotated with {@link SensitiveData}.
     */
    private boolean needToDecrypt(Object object) {
        if (object == null) {
            return false;
        }
        SensitiveData sensitiveData = AnnotationUtils.findAnnotation(object.getClass(), SensitiveData.class);
        return Objects.nonNull(sensitiveData);
    }
}
